{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autoencoder Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from scipy.stats import norm\n",
    "\n",
    "from models.AE import Autoencoder\n",
    "from utils.loaders import load_mnist, load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run params\n",
    "SECTION = 'vae'\n",
    "RUN_ID = '0001'\n",
    "DATA_NAME = 'digits'\n",
    "RUN_FOLDER = 'run/{}/'.format(SECTION)\n",
    "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n",
    "\n",
    "# Seed numpy's global RNG so the randomly sampled test images and latent points\n",
    "# used by the cells below are identical on every Restart & Run All.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load_mnist() returns a (train, test) pair, each of which is an (x, y) tuple\n",
    "train_split, test_split = load_mnist()\n",
    "x_train, y_train = train_split\n",
    "x_test, y_test = test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail early with a clear message instead of an opaque error from inside load_model\n",
    "assert os.path.isdir(RUN_FOLDER), 'Run folder {} not found - train and save the Autoencoder there first'.format(RUN_FOLDER)\n",
    "AE = load_model(Autoencoder, RUN_FOLDER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reconstructing original paintings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_to_show = 10\n",
    "# Sample distinct test images (np.random.choice draws with replacement by default)\n",
    "example_idx = np.random.choice(len(x_test), n_to_show, replace=False)\n",
    "example_images = x_test[example_idx]\n",
    "\n",
    "# Encode to latent points, then decode those points back into images\n",
    "z_points = AE.encoder.predict(example_images)\n",
    "reconst_images = AE.decoder.predict(z_points)\n",
    "\n",
    "fig = plt.figure(figsize=(15, 3))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "# Top row: originals annotated with their latent coordinates; bottom row: reconstructions\n",
    "for i in range(n_to_show):\n",
    "    ax = fig.add_subplot(2, n_to_show, i + 1)\n",
    "    ax.axis('off')\n",
    "    ax.text(0.5, -0.35, str(np.round(z_points[i], 1)), fontsize=10, ha='center', transform=ax.transAxes)\n",
    "    ax.imshow(example_images[i].squeeze(), cmap='gray_r')\n",
    "\n",
    "    ax = fig.add_subplot(2, n_to_show, n_to_show + i + 1)\n",
    "    ax.axis('off')\n",
    "    ax.imshow(reconst_images[i].squeeze(), cmap='gray_r')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mr N. Coder's wall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_to_show = min(5000, len(x_test))\n",
    "figsize = 12\n",
    "\n",
    "# Encode a random subset of distinct test images into the latent space\n",
    "example_idx = np.random.choice(len(x_test), n_to_show, replace=False)\n",
    "example_images = x_test[example_idx]\n",
    "z_points = AE.encoder.predict(example_images)\n",
    "\n",
    "# Bounding box of the encoded points; the next cell samples new latent points inside it\n",
    "min_x, max_x = z_points[:, 0].min(), z_points[:, 0].max()\n",
    "min_y, max_y = z_points[:, 1].min(), z_points[:, 1].max()\n",
    "\n",
    "plt.figure(figsize=(figsize, figsize))\n",
    "plt.scatter(z_points[:, 0], z_points[:, 1], c='black', alpha=0.5, s=2)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The new generated art exhibition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scatter_figsize = 5\n",
    "grid_size = 10\n",
    "grid_depth = 3\n",
    "gallery_width = 15\n",
    "\n",
    "# Draw random latent points uniformly from the bounding box of the encoded test images\n",
    "x = np.random.uniform(min_x, max_x, size=grid_size * grid_depth)\n",
    "y = np.random.uniform(min_y, max_y, size=grid_size * grid_depth)\n",
    "z_grid = np.column_stack((x, y))\n",
    "reconst = AE.decoder.predict(z_grid)\n",
    "\n",
    "# Encoded test images (black) with the sampled latent points overlaid (red)\n",
    "plt.figure(figsize=(scatter_figsize, scatter_figsize))\n",
    "plt.scatter(z_points[:, 0], z_points[:, 1], c='black', alpha=0.5, s=2)\n",
    "plt.scatter(z_grid[:, 0], z_grid[:, 1], c='red', alpha=1, s=20)\n",
    "plt.show()\n",
    "\n",
    "# Decoded image for each sampled point, labelled with its latent coordinates\n",
    "fig = plt.figure(figsize=(gallery_width, grid_depth))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "for i in range(grid_size * grid_depth):\n",
    "    ax = fig.add_subplot(grid_depth, grid_size, i + 1)\n",
    "    ax.axis('off')\n",
    "    ax.text(0.5, -0.35, str(np.round(z_grid[i], 1)), fontsize=10, ha='center', transform=ax.transAxes)\n",
    "    ax.imshow(reconst[i, :, :, 0], cmap='Greys')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_to_show = min(5000, len(x_test))\n",
    "figsize = 12\n",
    "\n",
    "# Encode distinct test images and colour each latent point by its label\n",
    "example_idx = np.random.choice(len(x_test), n_to_show, replace=False)\n",
    "example_images = x_test[example_idx]\n",
    "example_labels = y_test[example_idx]\n",
    "z_points = AE.encoder.predict(example_images)\n",
    "\n",
    "plt.figure(figsize=(figsize, figsize))\n",
    "plt.scatter(z_points[:, 0], z_points[:, 1], c=example_labels, cmap='rainbow', alpha=0.5, s=2)\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_to_show = min(5000, len(x_test))\n",
    "grid_size = 20\n",
    "figsize = 8\n",
    "scatter_figsize = 5\n",
    "\n",
    "example_idx = np.random.choice(len(x_test), n_to_show, replace=False)\n",
    "example_images = x_test[example_idx]\n",
    "example_labels = y_test[example_idx]\n",
    "z_points = AE.encoder.predict(example_images)\n",
    "\n",
    "# Regular grid spanning the encoded points. y runs from max to min so that the\n",
    "# row-major image grid below has the same orientation as the scatter plot.\n",
    "# Alternative (unused): quantile grid, e.g. norm.ppf(np.linspace(0.05, 0.95, grid_size))\n",
    "x = np.linspace(z_points[:, 0].min(), z_points[:, 0].max(), grid_size)\n",
    "y = np.linspace(z_points[:, 1].max(), z_points[:, 1].min(), grid_size)\n",
    "xv, yv = np.meshgrid(x, y)\n",
    "z_grid = np.column_stack((xv.ravel(), yv.ravel()))\n",
    "reconst = AE.decoder.predict(z_grid)\n",
    "\n",
    "# Encoded test images coloured by label, with the grid points overlaid in black\n",
    "plt.figure(figsize=(scatter_figsize, scatter_figsize))\n",
    "plt.scatter(z_points[:, 0], z_points[:, 1], c=example_labels, cmap='rainbow', alpha=0.5, s=2)\n",
    "plt.colorbar()\n",
    "plt.scatter(z_grid[:, 0], z_grid[:, 1], c='black', alpha=1, s=5)\n",
    "plt.show()\n",
    "\n",
    "# Decoded image at each grid point\n",
    "fig = plt.figure(figsize=(figsize, figsize))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "for i in range(grid_size ** 2):\n",
    "    ax = fig.add_subplot(grid_size, grid_size, i + 1)\n",
    "    ax.axis('off')\n",
    "    ax.imshow(reconst[i, :, :, 0], cmap='Greys')\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdl_code",
   "language": "python",
   "name": "gdl_code"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
